{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "V2RPZQF05ngq"
      },
      "source": [
        "##### Copyright 2019 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "PEig-M385xKS"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9afYQkHI52mr"
      },
      "source": [
        "# Step 2: Train a machine learning model"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "v8hkClAF58p6"
      },
      "source": [
        "This is the notebook for step 2 of the codelab [**Build a handwritten digit classifier app with TensorFlow Lite**](https://codelabs.developers.google.com/codelabs/digit-classifier-tflite/)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Yu1BOwfPzaIy"
      },
      "source": [
        "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/examples/blob/master/lite/codelabs/digit_classifier/ml/step2_train_ml_model.ipynb\"\u003e\n",
        "    \u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003e\n",
        "    Run in Google Colab\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/examples/blob/master/lite/codelabs/digit_classifier/ml/step2_train_ml_model.ipynb\"\u003e\n",
        "    \u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003e\n",
        "    View source on GitHub\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "\u003c/table\u003e"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EXyJkL4WnqyS"
      },
      "source": [
        "## Import dependencies"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XXX8WpQI5U6_"
      },
      "source": [
        "We start by importing TensorFlow and other supporting libraries that are used for data processing and visualization."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kS_mq4yAlXHZ"
      },
      "outputs": [],
      "source": [
        "# TensorFlow and tf.keras\n",
        "import tensorflow as tf\n",
        "from tensorflow import keras\n",
        "\n",
        "# Helper libraries\n",
        "import numpy as np\n",
        "import matplotlib.pyplot as plt\n",
        "import random\n",
        "\n",
        "print(tf.__version__)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "r0WlLrJcnwWp"
      },
      "source": [
        "## Download and explore the MNIST dataset\n",
        "The MNIST database contains 60,000 training images and 10,000 testing images of handwritten digits. We will use the dataset to train our digit classification model.\n",
        "\n",
        "Each example in the MNIST dataset consists of a 28x28 grayscale image of a digit from 0 to 9 and a label identifying which digit the image contains.\n",
        "![MNIST sample](https://github.com/khanhlvg/DigitClassifier/raw/master/images/mnist.png)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "B5REuMrblewK"
      },
      "outputs": [],
      "source": [
        "# Keras provides a handy API to download the MNIST dataset and split it into\n",
        "# a \"train\" dataset and a \"test\" dataset.\n",
        "mnist = keras.datasets.mnist\n",
        "(train_images, train_labels), (test_images, test_labels) = mnist.load_data()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "REY5lDDylpFh"
      },
      "outputs": [],
      "source": [
        "# Normalize the input images so that each pixel value (originally 0 to 255) is between 0 and 1.\n",
        "train_images = train_images / 255.0\n",
        "test_images = test_images / 255.0\n",
        "print('Pixels are normalized')"
      ]
    },
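    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Dividing by `255.0` rescales the 8-bit pixel values (0 to 255) into the range [0, 1]. As a minimal sketch, the cell below applies the same step to a tiny synthetic `uint8` array, which stands in for real MNIST images, to show the resulting value range and data type. It also shows the cast to `float32` that the TF Lite evaluation later in this notebook performs to match the model's input format."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "# A tiny synthetic 2x2 'image' standing in for an MNIST digit (assumption).\n",
        "# MNIST pixels are 8-bit unsigned integers in the range 0..255.\n",
        "fake_image = np.array([[0, 128], [200, 255]], dtype=np.uint8)\n",
        "\n",
        "# Same normalization as above: dividing by a float produces a float64 array.\n",
        "normalized = fake_image / 255.0\n",
        "print(normalized.dtype)                    # float64\n",
        "print(normalized.min(), normalized.max())  # 0.0 1.0\n",
        "\n",
        "# Cast to float32, the input type of the converted TF Lite model.\n",
        "as_float32 = normalized.astype(np.float32)\n",
        "print(as_float32.dtype)                    # float32"
      ]
    },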
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SAOE84IplyWR"
      },
      "outputs": [],
      "source": [
        "# Show the first 25 images in the training dataset.\n",
        "plt.figure(figsize=(10,10))\n",
        "for i in range(25):\n",
        "  plt.subplot(5,5,i+1)\n",
        "  plt.xticks([])\n",
        "  plt.yticks([])\n",
        "  plt.grid(False)\n",
        "  plt.imshow(train_images[i], cmap=plt.cm.gray)\n",
        "  plt.xlabel(train_labels[i])\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9v-Wp3TxpLX6"
      },
      "source": [
        "## Train a TensorFlow model to classify digit images\n",
        "\n",
        "Next, we use the Keras API to build a TensorFlow model and train it on the MNIST \"train\" dataset. After training, our model will be able to classify digit images.\n",
        "\n",
        "Our model takes **a 28 x 28 pixel grayscale image** as input and outputs **a float array of length 10** containing one score (logit) per digit from 0 to 9. The higher a score, the more likely the model considers the image to be that digit. Applying softmax converts the scores into probabilities.\n",
        "\n",
        "Here we use a simple convolutional neural network, which is a common technique in computer vision. We will not go into detail about the model architecture in this codelab. If you want a deeper understanding of different ML model architectures, please consider taking our free [TensorFlow training course](https://www.coursera.org/learn/introduction-tensorflow)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "IWgBGmaplzcp"
      },
      "outputs": [],
      "source": [
        "# Define the model architecture\n",
        "model = keras.Sequential([\n",
        "  keras.layers.InputLayer(input_shape=(28, 28)),\n",
        "  keras.layers.Reshape(target_shape=(28, 28, 1)),\n",
        "  keras.layers.Conv2D(filters=32, kernel_size=(3, 3), activation=tf.nn.relu),\n",
        "  keras.layers.Conv2D(filters=64, kernel_size=(3, 3), activation=tf.nn.relu),\n",
        "  keras.layers.MaxPooling2D(pool_size=(2, 2)),\n",
        "  keras.layers.Dropout(0.25),\n",
        "  keras.layers.Flatten(),\n",
        "  keras.layers.Dense(10)\n",
        "])\n",
        "\n",
        "# Define how to train the model\n",
        "model.compile(optimizer='adam',\n",
        "              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
        "              metrics=['accuracy'])\n",
        "\n",
        "# Train the digit classification model\n",
        "model.fit(train_images, train_labels, epochs=5)"
      ]
    },
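    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "In the model above, the final `Dense(10)` layer has no activation, and the loss is configured with `from_logits=True`. The model's raw outputs are therefore unnormalized scores (logits), not probabilities. The minimal NumPy sketch below uses made-up logits (hypothetical values, not real model output) to show how softmax turns them into probabilities, and why `argmax` picks the same digit either way. TensorFlow provides the same conversion as `tf.nn.softmax`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "def softmax(scores):\n",
        "  # Subtract the per-row maximum for numerical stability, then normalize.\n",
        "  shifted = scores - np.max(scores, axis=-1, keepdims=True)\n",
        "  exp = np.exp(shifted)\n",
        "  return exp / np.sum(exp, axis=-1, keepdims=True)\n",
        "\n",
        "# Hypothetical logits for a batch of one image (made-up values).\n",
        "logits = np.array([[1.0, -2.0, 0.5, 8.0, 0.0, -1.0, 2.0, 0.3, 1.5, -0.5]])\n",
        "probabilities = softmax(logits)\n",
        "\n",
        "print(np.isclose(probabilities.sum(), 1.0))  # True\n",
        "# Softmax preserves ordering, so the predicted digit does not change.\n",
        "print(np.argmax(logits, axis=1))             # [3]\n",
        "print(np.argmax(probabilities, axis=1))      # [3]"
      ]
    },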
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WFHKkb7gcJei"
      },
      "source": [
        "Let's take a closer look at our model structure."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "y7V6-UQqcuK-"
      },
      "outputs": [],
      "source": [
        "model.summary()"
      ]
    },
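    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To see where the numbers printed by `model.summary()` come from, the cell below recomputes each layer's output shape and parameter count for the architecture defined above. It is a minimal pure-Python sketch that relies on the settings used there: the `Conv2D` defaults of `'valid'` padding and stride 1, and a `MaxPooling2D` stride equal to its pool size. The helpers `conv2d_output` and `dense_params` are hypothetical names used only for this illustration. The summary reports the same shapes with a leading **None** batch dimension."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "def conv2d_output(height, width, channels_in, filters, kernel):\n",
        "  # 'valid' padding with stride 1 shrinks each spatial side by kernel - 1.\n",
        "  shape = (height - kernel + 1, width - kernel + 1, filters)\n",
        "  # One kernel x kernel x channels_in weight block plus one bias per filter.\n",
        "  params = kernel * kernel * channels_in * filters + filters\n",
        "  return shape, params\n",
        "\n",
        "def dense_params(inputs, units):\n",
        "  # Weight matrix (inputs x units) plus one bias per unit.\n",
        "  return inputs * units + units\n",
        "\n",
        "# Reshape turns (28, 28) into (28, 28, 1) and has no parameters.\n",
        "shape, conv1 = conv2d_output(28, 28, 1, filters=32, kernel=3)  # (26, 26, 32)\n",
        "shape, conv2 = conv2d_output(*shape, filters=64, kernel=3)     # (24, 24, 64)\n",
        "# MaxPooling2D with pool_size=(2, 2) halves each spatial side; no parameters.\n",
        "shape = (shape[0] // 2, shape[1] // 2, shape[2])               # (12, 12, 64)\n",
        "# Dropout keeps the shape and has no parameters; Flatten yields one vector.\n",
        "flat = shape[0] * shape[1] * shape[2]                          # 9216\n",
        "dense = dense_params(flat, 10)\n",
        "\n",
        "print(conv1, conv2, dense)    # 320 18496 92170\n",
        "print(conv1 + conv2 + dense)  # 110986"
      ]
    },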
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "n16JkSyNc5cf"
      },
      "source": [
        "Every layer's output shape has an extra leading dimension of size **None**, called the **batch dimension**. In machine learning, we usually process data in batches to improve throughput, so TensorFlow automatically adds this dimension for you. **None** means that the batch size can vary."
      ]
    },
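    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Because the model expects a leading batch dimension, a single 28x28 image has to be wrapped into a batch of one before it can be fed to the model. The TF Lite evaluation later in this notebook does this with `np.expand_dims`. The minimal sketch below uses zero-filled placeholder arrays, not MNIST data:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "# A placeholder 28x28 image (all zeros), standing in for one MNIST digit.\n",
        "single_image = np.zeros((28, 28), dtype=np.float32)\n",
        "\n",
        "# Add a batch dimension of size 1 in front: (28, 28) -> (1, 28, 28).\n",
        "batch_of_one = np.expand_dims(single_image, axis=0)\n",
        "print(batch_of_one.shape)    # (1, 28, 28)\n",
        "\n",
        "# Stacking N images gives a batch of size N in the same leading position,\n",
        "# which is the dimension that model.summary() reports as None.\n",
        "batch_of_eight = np.stack([single_image] * 8)\n",
        "print(batch_of_eight.shape)  # (8, 28, 28)"
      ]
    },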
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "za35DFJ5uFkX"
      },
      "source": [
        "## Evaluate our model\n",
        "We run our digit classification model against the \"test\" dataset, which the model has not seen during training. This confirms that the model did not just memorize the digits it saw but also generalizes well to new images."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "sJI8nqFWmMwC"
      },
      "outputs": [],
      "source": [
        "# Evaluate the model using all images in the test dataset.\n",
        "test_loss, test_acc = model.evaluate(test_images, test_labels)\n",
        "\n",
        "print('Test accuracy:', test_acc)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7-qv9-9_cUb7"
      },
      "source": [
        "Although our model is relatively simple, it achieves an accuracy of around 98% on new images that it has never seen before. Let's visualize the result."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "PdlXEyUImeXC"
      },
      "outputs": [],
      "source": [
        "# A helper function that returns 'black' if its two input parameters match\n",
        "# and 'red' otherwise.\n",
        "def get_label_color(val1, val2):\n",
        "  if val1 == val2:\n",
        "    return 'black'\n",
        "  else:\n",
        "    return 'red'\n",
        "\n",
        "# Predict the labels of digit images in our test dataset.\n",
        "predictions = model.predict(test_images)\n",
        "\n",
        "# The model outputs 10 floats (logits), one score per digit from 0 to 9, so we\n",
        "# take the index of the largest value to find out which digit the model\n",
        "# predicts to be most likely in the image.\n",
        "prediction_digits = np.argmax(predictions, axis=1)\n",
        "\n",
        "# Then plot 100 random test images and their predicted labels.\n",
        "# If a prediction differs from the label provided in the \"test\" dataset,\n",
        "# we highlight it in red.\n",
        "plt.figure(figsize=(18, 18))\n",
        "for i in range(100):\n",
        "  ax = plt.subplot(10, 10, i+1)\n",
        "  plt.xticks([])\n",
        "  plt.yticks([])\n",
        "  plt.grid(False)\n",
        "  # randint includes both endpoints, so len - 1 is the largest valid index.\n",
        "  image_index = random.randint(0, len(prediction_digits) - 1)\n",
        "  plt.imshow(test_images[image_index], cmap=plt.cm.gray)\n",
        "  ax.xaxis.label.set_color(get_label_color(prediction_digits[image_index],\n",
        "                                           test_labels[image_index]))\n",
        "  plt.xlabel('Predicted: %d' % prediction_digits[image_index])\n",
        "plt.show()"
      ]
    },
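    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "`random.randint(a, b)` includes both endpoints. Valid indices for an array of length `n` are therefore drawn with `random.randint(0, n - 1)`; using `n` as the upper bound can occasionally produce an out-of-range index. Independent draws can also pick the same image twice. As a minimal sketch, `random.sample` picks distinct indices instead. The size 10,000 matches the MNIST test set."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import random\n",
        "\n",
        "n = 10000  # number of images in the MNIST test set\n",
        "\n",
        "# randint includes both endpoints, so n - 1 is the largest valid index.\n",
        "index = random.randint(0, n - 1)\n",
        "\n",
        "# random.sample draws 100 distinct indices, so no image is shown twice.\n",
        "indices = random.sample(range(n), 100)\n",
        "print(len(indices), len(set(indices)))  # 100 100"
      ]
    },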
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AWROBI4iv9fY"
      },
      "source": [
        "## Convert the Keras model to TensorFlow Lite"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bV99Izwykb-J"
      },
      "source": [
        "Now that we have trained the digit classifier model, we will convert it to the TensorFlow Lite format for mobile deployment."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2fXStjR4mzkR"
      },
      "outputs": [],
      "source": [
        "# Convert Keras model to TF Lite format.\n",
        "converter = tf.lite.TFLiteConverter.from_keras_model(model)\n",
        "tflite_float_model = converter.convert()\n",
        "\n",
        "# Show model size in KBs.\n",
        "float_model_size = len(tflite_float_model) / 1024\n",
        "print('Float model size = %dKBs.' % float_model_size)"
      ]
    },
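    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a rough cross-check, the float model stores each parameter as a 4-byte `float32`. The sketch below estimates the weight storage from the 110,986 parameters of the architecture defined above. This is only an estimate: the `.tflite` file also stores the model graph and other metadata, so the file size printed above is expected to be somewhat larger."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "num_params = 110986    # total parameter count of the CNN defined above\n",
        "bytes_per_float32 = 4  # each float32 weight or bias takes 4 bytes\n",
        "\n",
        "estimated_kb = num_params * bytes_per_float32 / 1024\n",
        "print('Estimated size of the float weights = %.1f KB' % estimated_kb)  # ... = 433.5 KB"
      ]
    },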
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tfer6hI8ljEh"
      },
      "source": [
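        "Because we will deploy our model to a mobile device, we want it to be as small and as fast as possible. **Quantization** is a common technique in on-device machine learning to shrink ML models. Here we use 8-bit integers to approximate our 32-bit floating-point weights, which shrinks the model size by roughly a factor of 4. Setting `tf.lite.Optimize.DEFAULT` without a representative dataset applies post-training *dynamic range quantization*, which converts the weights to 8-bit integers ahead of time.\n",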
        "\n",
        "See [TensorFlow documentation](https://www.tensorflow.org/lite/performance/post_training_quantization) to learn more about other quantization techniques."
      ]
    },
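    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The idea behind the 4x reduction can be sketched in NumPy. The cell below uses a simple symmetric, per-tensor scheme chosen for illustration; this is an assumption, and the TF Lite converter's actual quantization parameters and rounding details may differ. Each `float32` weight is mapped to an `int8` value, and one shared scale maps it back. The random weights stand in for a real layer and have the same shape as the second `Conv2D` kernel above."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "rng = np.random.default_rng(0)\n",
        "# Random float32 weights standing in for a real 3x3x32x64 Conv2D kernel.\n",
        "weights = rng.normal(scale=0.1, size=(3, 3, 32, 64)).astype(np.float32)\n",
        "\n",
        "# Symmetric per-tensor quantization: map [-max|w|, max|w|] onto [-127, 127].\n",
        "scale = np.max(np.abs(weights)) / 127.0\n",
        "quantized = np.clip(np.round(weights / scale), -127, 127).astype(np.int8)\n",
        "\n",
        "# Dequantize to approximate the original float weights.\n",
        "restored = quantized.astype(np.float32) * scale\n",
        "\n",
        "# int8 uses 1 byte per value instead of 4 bytes for float32.\n",
        "print(weights.nbytes / quantized.nbytes)  # 4.0\n",
        "# Rounding to the nearest step keeps each error within about half a step.\n",
        "print(np.max(np.abs(weights - restored)) <= scale / 2 + 1e-6)  # True"
      ]
    },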
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "yhY86SRTmtGC"
      },
      "outputs": [],
      "source": [
        "# Re-convert the model to TF Lite using quantization.\n",
        "converter.optimizations = [tf.lite.Optimize.DEFAULT]\n",
        "tflite_quantized_model = converter.convert()\n",
        "\n",
        "# Show model size in KBs.\n",
        "quantized_model_size = len(tflite_quantized_model) / 1024\n",
        "print('Quantized model size = %dKBs,' % quantized_model_size)\n",
        "print('which is about %d%% of the float model size.'\\\n",
        "      % (quantized_model_size * 100 / float_model_size))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ahTP3T60nYJb"
      },
      "source": [
        "## Evaluate the TensorFlow Lite model\n",
        "\n",
        "Quantization often trades a bit of accuracy for a significantly smaller model. Let's calculate the accuracy drop of our quantized model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YvszGa11ne6Q"
      },
      "outputs": [],
      "source": [
        "# A helper function to evaluate the TF Lite model using \"test\" dataset.\n",
        "def evaluate_tflite_model(tflite_model):\n",
        "  # Initialize TFLite interpreter using the model.\n",
        "  interpreter = tf.lite.Interpreter(model_content=tflite_model)\n",
        "  interpreter.allocate_tensors()\n",
        "  input_tensor_index = interpreter.get_input_details()[0][\"index\"]\n",
        "  output = interpreter.tensor(interpreter.get_output_details()[0][\"index\"])\n",
        "\n",
        "  # Run predictions on every image in the \"test\" dataset.\n",
        "  prediction_digits = []\n",
        "  for test_image in test_images:\n",
        "    # Pre-processing: add the batch dimension and convert to float32 to match\n",
        "    # the model's input data format.\n",
        "    test_image = np.expand_dims(test_image, axis=0).astype(np.float32)\n",
        "    interpreter.set_tensor(input_tensor_index, test_image)\n",
        "\n",
        "    # Run inference.\n",
        "    interpreter.invoke()\n",
        "\n",
        "    # Post-processing: remove the batch dimension and find the digit with the\n",
        "    # highest score.\n",
        "    digit = np.argmax(output()[0])\n",
        "    prediction_digits.append(digit)\n",
        "\n",
        "  # Compare prediction results with ground truth labels to calculate accuracy.\n",
        "  accurate_count = 0\n",
        "  for index in range(len(prediction_digits)):\n",
        "    if prediction_digits[index] == test_labels[index]:\n",
        "      accurate_count += 1\n",
        "  accuracy = accurate_count * 1.0 / len(prediction_digits)\n",
        "\n",
        "  return accuracy\n",
        "\n",
        "# Evaluate the TF Lite float model. You'll find that its accuracy is essentially\n",
        "# identical to the original TF (Keras) model because they are the same model\n",
        "# stored in different formats.\n",
        "float_accuracy = evaluate_tflite_model(tflite_float_model)\n",
        "print('Float model accuracy = %.4f' % float_accuracy)\n",
        "\n",
        "# Evaluate the TF Lite quantized model.\n",
        "# Don't be surprised if the quantized model's accuracy is slightly higher than\n",
        "# the float model's. It happens sometimes :)\n",
        "quantized_accuracy = evaluate_tflite_model(tflite_quantized_model)\n",
        "print('Quantized model accuracy = %.4f' % quantized_accuracy)\n",
        "print('Accuracy drop = %.4f' % (float_accuracy - quantized_accuracy))\n"
      ]
    },
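    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The counting loop inside `evaluate_tflite_model` can also be expressed with NumPy array operations. The minimal sketch below uses five hypothetical predictions and labels (made-up values, not results from this notebook). Comparing the arrays element-wise gives a boolean array whose mean is the accuracy, and the positions of the `False` entries are the misclassified images."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "# Hypothetical predicted digits and ground-truth labels for five images.\n",
        "example_predictions = np.array([7, 2, 1, 0, 4])\n",
        "example_labels = np.array([7, 2, 1, 0, 9])\n",
        "\n",
        "# True where the prediction matches the label; the mean is the accuracy.\n",
        "matches = example_predictions == example_labels\n",
        "accuracy = np.mean(matches)\n",
        "print('Accuracy = %.4f' % accuracy)  # Accuracy = 0.8000\n",
        "\n",
        "# Indices of the misclassified images, e.g. for closer inspection.\n",
        "print(np.flatnonzero(~matches))      # [4]"
      ]
    },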
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ItyEwAdCCVw6"
      },
      "source": [
        "## Download the TensorFlow Lite model\n",
        "\n",
        "Let's download our model so that we can integrate it into an Android app.\n",
        "\n",
        "If you see an error when downloading `mnist.tflite` from Colab, try running the cell below again."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Q_Z5yLxrwbpI"
      },
      "outputs": [],
      "source": [
        "# Save the quantized model to a file in the current working directory.\n",
        "with open('mnist.tflite', 'wb') as f:\n",
        "  f.write(tflite_quantized_model)\n",
        "\n",
        "# Download the digit classification model\n",
        "from google.colab import files\n",
        "files.download('mnist.tflite')\n",
        "\n",
        "print('`mnist.tflite` has been downloaded')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "C4ASalaLIbu2"
      },
      "source": [
        "## Good job!\n",
        "This is the end of *Step 2: Train a machine learning model* in the codelab **Build a handwritten digit classifier app with TensorFlow Lite**. Let's go back to our codelab and proceed to the [next step](https://codelabs.developers.google.com/codelabs/digit-classifier-tflite/#2)."
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [],
      "name": "step2_train_ml_model.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
